#include "bezier_generator3d.h"
using namespace std;
using namespace Eigen;

// static void MSKAPI printstr(void *handle, MSKCONST char str[]) {
//   printf("%s", str);
// }

int BezierPloyCoeffGeneration3D(
    const vector<Cube> &corridor, const MatrixXd &MQM, const MatrixXd &pos,
    const MatrixXd &vel, const MatrixXd &acc, const double maxVel,
    const double maxAcc, const int traj_order, const double minimize_order,
    const double margin, const bool &isLimitVel, const bool &isLimitAcc,
    double &obj,
    MatrixXd &PolyCoeff) // define the order to which we minimize.   1 --
                         // velocity, 2 -- acceleration, 3 -- jerk, 4 -- snap
{
// whether or not adding extra constraints for ensuring the velocity
// feasibility
#define ENFORCE_VEL isLimitVel
// whether or not adding extra constraints for ensuring the acceleration
// feasibility
#define ENFORCE_ACC isLimitAcc

  double initScale = corridor.front().m_t;
  double lstScale = corridor.back().m_t;
  int segment_num = corridor.size();
  // 多项式系数=多项式阶数+1
  int n_poly = traj_order + 1;
  // 多项式系数=控制点个数
  int s1d1CtrlP_num = n_poly;
  // 三个轴上的控制点的个数
  int s1CtrlP_num = 3 * s1d1CtrlP_num;
  // 起始点约束数量
  int equ_con_s_num = 3 * 3; // p, v, a in x, y, z axis at the start point
  // 终止点约束数量
  int equ_con_e_num = 3 * 3; // p, v, a in x, y, z axis at the end point
  // 连续性约束数量
  int equ_con_continuity_num = 3 * 3 * (segment_num - 1);
  // 总的约束数量
  int equ_con_num = equ_con_s_num + equ_con_e_num + equ_con_continuity_num;
  // p, v, a in x, y, z axis in each
  // segment's joint position
  // 速度约束数量(每一段速度只有 traj_order 个控制点)
  int vel_con_num = 3 * traj_order * segment_num;
  // 加速度约束数量(每一段加速度只有 traj_order-1 个控制点)
  int acc_con_num = 3 * (traj_order - 1) * segment_num;

  if (!ENFORCE_VEL)
    vel_con_num = 0;

  if (!ENFORCE_ACC)
    acc_con_num = 0;

  int high_order_con_num = vel_con_num + acc_con_num;
  // int high_order_con_num = 0; //3 * traj_order * segment_num;
  // 总的约束数量
  int con_num = equ_con_num + high_order_con_num;
  // 总的控制点个数
  int ctrlP_num = segment_num * s1CtrlP_num;
  // x_var用于存储优化过程中计算出的控制点变量
  double x_var[ctrlP_num];
  // primalobj用于存储优化的最终目标值
  double primalobj;
  // MSKboundkeye 是一个枚举类型，用于定义约束或变量的边界类型。
  // 它表示一个约束或变量的取值范围形式，例如是否有上下界、仅上界、仅下界，或者没有边界。

  vector<pair<MSKboundkeye, pair<double, double>>> con_bdk;

  if (ENFORCE_VEL) {
    /***  Stack the bounding value for the linear inequality for the velocity
     * constraints  ***/
    for (int i = 0; i < vel_con_num; i++) {
      pair<MSKboundkeye, pair<double, double>> cb_ie =
          make_pair(MSK_BK_RA, make_pair(-maxVel, +maxVel));
      con_bdk.push_back(cb_ie);
    }
  }

  if (ENFORCE_ACC) {
    /***  Stack the bounding value for the linear inequality for the
     * acceleration constraints  ***/
    for (int i = 0; i < acc_con_num; i++) {
      pair<MSKboundkeye, pair<double, double>> cb_ie =
          make_pair(MSK_BK_RA, make_pair(-maxAcc, maxAcc));
      con_bdk.push_back(cb_ie);
    }
  }

  // ROS_WARN("[Bezier Trajectory] equality bound %d", equ_con_num);
  for (int i = 0; i < equ_con_num; i++) {
    double beq_i;
    if (i < 3)
      beq_i = pos(0, i);
    else if (i >= 3 && i < 6)
      beq_i = vel(0, i - 3);
    else if (i >= 6 && i < 9)
      beq_i = acc(0, i - 6);
    else if (i >= 9 && i < 12)
      beq_i = pos(1, i - 9);
    else if (i >= 12 && i < 15)
      beq_i = vel(1, i - 12);
    else if (i >= 15 && i < 18)
      beq_i = acc(1, i - 15);
    else
      beq_i = 0.0;

    pair<MSKboundkeye, pair<double, double>> cb_eq = make_pair(
        MSK_BK_FX, make_pair(beq_i, beq_i)); // # cb_eq means: constriants
                                             // boundary of equality constrain
    con_bdk.push_back(cb_eq);
  }

  /* ## define a container for control points' boundary and boundkey ## */
  /* ## dataType in one tuple is : boundary type, lower bound, upper bound ## */
  vector<pair<MSKboundkeye, pair<double, double>>> var_bdk;

  for (int k = 0; k < segment_num; k++) {
    Cube cube_ = corridor[k];
    double scale_k = cube_.m_t;

    for (int i = 0; i < 3; i++) {
      for (int j = 0; j < n_poly; j++) {
        pair<MSKboundkeye, pair<double, double>> vb_x;

        double lo_bound, up_bound;
        if (k > 0) {
          lo_bound = (cube_.box[i].first + margin) / scale_k;
          up_bound = (cube_.box[i].second - margin) / scale_k;
        } else {
          lo_bound = (cube_.box[i].first) / scale_k;
          up_bound = (cube_.box[i].second) / scale_k;
        }

        vb_x = make_pair(
            MSK_BK_RA,
            make_pair(lo_bound, up_bound)); // # vb_x means: varialbles boundary
                                            // of unknowns x (Polynomial coeff)

        var_bdk.push_back(vb_x);
      }
    }
  }
  MSKrescodee r;
  MSKint32t j, i;
  MSKenv_t env;
  MSKtask_t task;
  // Create the mosek environment.
  r = MSK_makeenv(&env, NULL);

  // Create the optimization task.
  r = MSK_maketask(env, con_num, ctrlP_num, &task);

  // Parameters used in the optimizer
  // ######################################################################
  // MSK_putintparam (task, MSK_IPAR_OPTIMIZER , MSK_OPTIMIZER_INTPNT );
  // 设置用于优化的线程数
  MSK_putintparam(task, MSK_IPAR_NUM_THREADS, 1);
  // 设置 MOSEK 优化器的双精度参数，用来控制 凸性检查的相对容差
  MSK_putdouparam(task, MSK_DPAR_CHECK_CONVEXITY_REL_TOL, 1e-2);
  // 设置 MOSEK 内点法求解器的双精度参数，用于控制优化问题的 对偶可行性容差
  MSK_putdouparam(task, MSK_DPAR_INTPNT_TOL_DFEAS, 1e-4);
  // 设置 MOSEK 内点法求解器的双精度参数，用于控制优化问题的 原始可行性容差
  MSK_putdouparam(task, MSK_DPAR_INTPNT_TOL_PFEAS, 1e-4);
  // 设置 MOSEK 内点法求解器的双精度参数，用于控制优化问题的 不可行性容差
  MSK_putdouparam(task, MSK_DPAR_INTPNT_TOL_INFEAS, 1e-4);
  // MSK_putdouparam (task, MSK_DPAR_INTPNT_TOL_REL_GAP, 5e-2 );
  // ######################################################################

  // r = MSK_linkfunctotaskstream(task,MSK_STREAM_LOG,NULL,printstr);
  //  Append empty constraints.
  // The constraints will initially have no bounds.
  if (r == MSK_RES_OK)
    r = MSK_appendcons(task, con_num);

  // Append optimizing variables. The variables will initially be fixed at zero
  // (x=0).
  if (r == MSK_RES_OK)
    r = MSK_appendvars(task, ctrlP_num);

  // ROS_WARN("set variables boundary");
  // 设置每一个控制点的位置约束
  for (j = 0; j < ctrlP_num && r == MSK_RES_OK; ++j) {
    if (r == MSK_RES_OK)
      r = MSK_putvarbound(
          task,
          j,                         // Index of variable.
          var_bdk[j].first,          // Bound key.
          var_bdk[j].second.first,   // Numerical value of lower bound.
          var_bdk[j].second.second); // Numerical value of upper bound.
  }

  // Set the bounds on constraints.
  //   for i=1, ...,con_num : blc[i] <= constraint i <= buc[i]
  // 传入速度，加速度的不等式约束
  // 传入起点，终点的p, v, a约束
  // 传入连续性约束
  for (i = 0; i < con_num && r == MSK_RES_OK; i++) {
    r = MSK_putconbound(
        task,
        i,                         // Index of constraint.
        con_bdk[i].first,          // Bound key.
        con_bdk[i].second.first,   // Numerical value of lower bound.
        con_bdk[i].second.second); // Numerical value of upper bound.
  }

  // ROS_WARN("[Bezier Trajectory] Start stacking the Linear Matrix A,
  // inequality part");
  int row_idx = 0;
  // The velocity constraints
  if (ENFORCE_VEL) {
    for (int k = 0; k < segment_num; k++) {
      for (int i = 0; i < 3; i++) { // for x, y, z loop
        for (int p = 0; p < traj_order; p++) {
          // 每个约束行有两个非零元素，表示对于速度约束的线性关系
          //
          // asub[0] 和 asub[1] 是 MOSEK 中约束矩阵 A
          // 的列索引数组，表示约束方程中系数所在的列。此处的 asub
          // 表示两个相邻的多项式系数：当前系数（p）和下一个系数（p+1）
          //
          // aval[0] 和 aval[1] 是系数数组，表示多项式系数的系数值。这里设置的
          // aval[0] = -traj_order 和 aval[1] = traj_order
          // 表示约束方程是两个多项式系数之间的差异，这样就定义了速度的变化。
          int nzi = 2;
          MSKint32t asub[nzi];
          double aval[nzi];

          aval[0] = -1.0 * traj_order;
          aval[1] = 1.0 * traj_order;

          asub[0] = k * s1CtrlP_num + i * s1d1CtrlP_num + p;
          asub[1] = k * s1CtrlP_num + i * s1d1CtrlP_num + p + 1;

          r = MSK_putarow(task, row_idx, nzi, asub, aval);
          row_idx++;
        }
      }
    }
  }

  // The acceleration constraints
  if (ENFORCE_ACC) {
    for (int k = 0; k < segment_num; k++) {
      for (int i = 0; i < 3; i++) {
        for (int p = 0; p < traj_order - 1; p++) {
          int nzi = 3;
          MSKint32t asub[nzi];
          double aval[nzi];

          aval[0] = 1.0 * traj_order * (traj_order - 1) / corridor[k].m_t;
          aval[1] = -2.0 * traj_order * (traj_order - 1) / corridor[k].m_t;
          aval[2] = 1.0 * traj_order * (traj_order - 1) / corridor[k].m_t;
          asub[0] = k * s1CtrlP_num + i * s1d1CtrlP_num + p;
          asub[1] = k * s1CtrlP_num + i * s1d1CtrlP_num + p + 1;
          asub[2] = k * s1CtrlP_num + i * s1d1CtrlP_num + p + 2;

          r = MSK_putarow(task, row_idx, nzi, asub, aval);
          row_idx++;
        }
      }
    }
  }
  /*   Start position  */
  {
    // position :
    for (int i = 0; i < 3; i++) { // loop for x, y, z
      int nzi = 1;
      MSKint32t asub[nzi];
      double aval[nzi];
      aval[0] = 1.0 * initScale;
      asub[0] = i * s1d1CtrlP_num;
      r = MSK_putarow(task, row_idx, nzi, asub, aval);
      row_idx++;
    }
    // velocity :
    for (int i = 0; i < 3; i++) { // loop for x, y, z
      int nzi = 2;
      MSKint32t asub[nzi];
      double aval[nzi];
      aval[0] = -1.0 * traj_order;
      aval[1] = 1.0 * traj_order;
      asub[0] = i * s1d1CtrlP_num;
      asub[1] = i * s1d1CtrlP_num + 1;
      r = MSK_putarow(task, row_idx, nzi, asub, aval);
      row_idx++;
    }
    // acceleration :
    for (int i = 0; i < 3; i++) { // loop for x, y, z
      int nzi = 3;
      MSKint32t asub[nzi];
      double aval[nzi];
      aval[0] = 1.0 * traj_order * (traj_order - 1) / initScale;
      aval[1] = -2.0 * traj_order * (traj_order - 1) / initScale;
      aval[2] = 1.0 * traj_order * (traj_order - 1) / initScale;
      asub[0] = i * s1d1CtrlP_num;
      asub[1] = i * s1d1CtrlP_num + 1;
      asub[2] = i * s1d1CtrlP_num + 2;
      r = MSK_putarow(task, row_idx, nzi, asub, aval);
      row_idx++;
    }
  }

  /*   End position  */
  // ROS_WARN(" end position");
  {
    // position :
    for (int i = 0; i < 3; i++) { // loop for x, y, z
      int nzi = 1;
      MSKint32t asub[nzi];
      double aval[nzi];
      asub[0] = ctrlP_num - 1 - (2 - i) * s1d1CtrlP_num;
      aval[0] = 1.0 * lstScale;
      r = MSK_putarow(task, row_idx, nzi, asub, aval);
      row_idx++;
    }
    // velocity :
    for (int i = 0; i < 3; i++) {
      int nzi = 2;
      MSKint32t asub[nzi];
      double aval[nzi];
      asub[0] = ctrlP_num - 1 - (2 - i) * s1d1CtrlP_num - 1;
      asub[1] = ctrlP_num - 1 - (2 - i) * s1d1CtrlP_num;
      aval[0] = -1.0;
      aval[1] = 1.0;
      r = MSK_putarow(task, row_idx, nzi, asub, aval);
      row_idx++;
    }
    // acceleration :
    for (int i = 0; i < 3; i++) {
      int nzi = 3;
      MSKint32t asub[nzi];
      double aval[nzi];
      asub[0] = ctrlP_num - 1 - (2 - i) * s1d1CtrlP_num - 2;
      asub[1] = ctrlP_num - 1 - (2 - i) * s1d1CtrlP_num - 1;
      asub[2] = ctrlP_num - 1 - (2 - i) * s1d1CtrlP_num;
      aval[0] = 1.0 / lstScale;
      aval[1] = -2.0 / lstScale;
      aval[2] = 1.0 / lstScale;
      r = MSK_putarow(task, row_idx, nzi, asub, aval);
      row_idx++;
    }
  }

  /*   joint points  */
  // ROS_WARN(" joint position");
  {
    int sub_shift = 0;
    double val0, val1;
    for (int k = 0; k < (segment_num - 1); k++) {
      double scale_k = corridor[k].m_t;
      double scale_n = corridor[k + 1].m_t;
      // position :
      val0 = scale_k;
      val1 = scale_n;
      for (int i = 0; i < 3; i++) { // loop for x, y, z
        int nzi = 2;
        MSKint32t asub[nzi];
        double aval[nzi];

        // This segment's last control point
        aval[0] = 1.0 * val0;
        asub[0] = sub_shift + (i + 1) * s1d1CtrlP_num - 1;

        // Next segment's first control point
        aval[1] = -1.0 * val1;
        asub[1] = sub_shift + s1CtrlP_num + i * s1d1CtrlP_num;
        r = MSK_putarow(task, row_idx, nzi, asub, aval);
        row_idx++;
      }

      for (int i = 0; i < 3; i++) {
        int nzi = 4;
        MSKint32t asub[nzi];
        double aval[nzi];

        // This segment's last velocity control point
        aval[0] = -1.0;
        aval[1] = 1.0;
        asub[0] = sub_shift + (i + 1) * s1d1CtrlP_num - 2;
        asub[1] = sub_shift + (i + 1) * s1d1CtrlP_num - 1;
        // Next segment's first velocity control point
        aval[2] = 1.0;
        aval[3] = -1.0;

        asub[2] = sub_shift + s1CtrlP_num + i * s1d1CtrlP_num;
        asub[3] = sub_shift + s1CtrlP_num + i * s1d1CtrlP_num + 1;

        r = MSK_putarow(task, row_idx, nzi, asub, aval);
        row_idx++;
      }
      // acceleration :
      val0 = 1.0 / scale_k;
      val1 = 1.0 / scale_n;
      for (int i = 0; i < 3; i++) {
        int nzi = 6;
        MSKint32t asub[nzi];
        double aval[nzi];

        // This segment's last velocity control point
        aval[0] = 1.0 * val0;
        aval[1] = -2.0 * val0;
        aval[2] = 1.0 * val0;
        asub[0] = sub_shift + (i + 1) * s1d1CtrlP_num - 3;
        asub[1] = sub_shift + (i + 1) * s1d1CtrlP_num - 2;
        asub[2] = sub_shift + (i + 1) * s1d1CtrlP_num - 1;
        // Next segment's first velocity control point
        aval[3] = -1.0 * val1;
        aval[4] = 2.0 * val1;
        aval[5] = -1.0 * val1;
        asub[3] = sub_shift + s1CtrlP_num + i * s1d1CtrlP_num;
        asub[4] = sub_shift + s1CtrlP_num + i * s1d1CtrlP_num + 1;
        asub[5] = sub_shift + s1CtrlP_num + i * s1d1CtrlP_num + 2;

        r = MSK_putarow(task, row_idx, nzi, asub, aval);
        row_idx++;
      }

      sub_shift += s1CtrlP_num;
    }
  }

  // ROS_WARN("[Bezier Trajectory] Start stacking the objective");

  int min_order_l = floor(minimize_order);
  int min_order_u = ceil(minimize_order);

  int NUMQNZ = 0;
  for (int i = 0; i < segment_num; i++) {
    int NUMQ_blk =
        (traj_order + 1); // default minimize the jerk and minimize_order = 3
    NUMQNZ += 3 * NUMQ_blk * (NUMQ_blk + 1) / 2;
  }
  MSKint32t qsubi[NUMQNZ], qsubj[NUMQNZ];
  double qval[NUMQNZ];

  {
    int sub_shift = 0;
    int idx = 0;
    for (int k = 0; k < segment_num; k++) {
      double scale_k = corridor[k].m_t;
      for (int p = 0; p < 3; p++)
        for (int i = 0; i < s1d1CtrlP_num; i++)
          for (int j = 0; j < s1d1CtrlP_num; j++)
            if (i >= j) {
              qsubi[idx] = sub_shift + p * s1d1CtrlP_num + i;
              qsubj[idx] = sub_shift + p * s1d1CtrlP_num + j;
              // qval[idx]  = MQM(i, j) /(double)pow(scale_k, 3);
              if (min_order_l == min_order_u)
                qval[idx] =
                    MQM(i, j) / (double)pow(scale_k, 2 * min_order_u - 3);
              else
                qval[idx] = ((minimize_order - min_order_l) /
                                 (double)pow(scale_k, 2 * min_order_u - 3) +
                             (min_order_u - minimize_order) /
                                 (double)pow(scale_k, 2 * min_order_l - 3)) *
                            MQM(i, j);
              idx++;
            }

      sub_shift += s1CtrlP_num;
    }
  }

  // ros::Time time_end1 = ros::Time::now();

  if (r == MSK_RES_OK)
    r = MSK_putqobj(task, NUMQNZ, qsubi, qsubj, qval);

  if (r == MSK_RES_OK)
    r = MSK_putobjsense(task, MSK_OBJECTIVE_SENSE_MINIMIZE);

  // ros::Time time_opt = ros::Time::now();
  bool solve_ok = false;
  if (r == MSK_RES_OK) {
    // ROS_WARN("Prepare to solve the problem ");
    MSKrescodee trmcode;
    r = MSK_optimizetrm(task, &trmcode);
    MSK_solutionsummary(task, MSK_STREAM_LOG);

    if (r == MSK_RES_OK) {
      MSKsolstae solsta;
      MSK_getsolsta(task, MSK_SOL_ITR, &solsta);

      switch (solsta) {
      case MSK_SOL_STA_OPTIMAL:
      case MSK_SOL_STA_NEAR_OPTIMAL:

        r = MSK_getxx(task,
                      MSK_SOL_ITR, // Request the interior solution.
                      x_var);

        r = MSK_getprimalobj(task, MSK_SOL_ITR, &primalobj);

        obj = primalobj;
        solve_ok = true;

        break;

      case MSK_SOL_STA_DUAL_INFEAS_CER:
      case MSK_SOL_STA_PRIM_INFEAS_CER:
      case MSK_SOL_STA_NEAR_DUAL_INFEAS_CER:
      case MSK_SOL_STA_NEAR_PRIM_INFEAS_CER:
        printf("Primal or dual infeasibility certificate found.\n");
        break;

      case MSK_SOL_STA_UNKNOWN:
        printf("The status of the solution could not be determined.\n");
        // solve_ok = true; // debug
        break;
      default:
        printf("Other solution status.");
        break;
      }
    } else {
      printf("Error while optimizing.\n");
    }
  }

  if (r != MSK_RES_OK) {
    // In case of an error print error code and description.
    char symname[MSK_MAX_STR_LEN];
    char desc[MSK_MAX_STR_LEN];

    printf("An error occurred while optimizing.\n");
    MSK_getcodedesc(r, symname, desc);
    printf("Error %s - '%s'\n", symname, desc);
  }

  MSK_deletetask(&task);
  MSK_deleteenv(&env);

  // ros::Time time_end2 = ros::Time::now();
  // ROS_WARN("time consume in optimize is :");
  // cout << time_end2 - time_end1 << endl;

  if (!solve_ok) {
    ROS_WARN("In solver, falied ");
    return -1;
  }

  VectorXd d_var(ctrlP_num);
  for (int i = 0; i < ctrlP_num; i++)
    d_var(i) = x_var[i];

  PolyCoeff = MatrixXd::Zero(segment_num, 3 * (traj_order + 1));

  int var_shift = 0;
  for (int i = 0; i < segment_num; i++) {
    for (int j = 0; j < 3 * n_poly; j++)
      PolyCoeff(i, j) = d_var(j + var_shift);

    var_shift += 3 * n_poly;
  }

  return 1;
}